import json
import os

def _read_jsonl(path):
    """Return the JSON records stored one per line in *path*.

    Blank or whitespace-only lines are skipped. If the file does not exist,
    a warning is printed and an empty list is returned.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a stray empty line at the end of the
            # file) so json.loads is never handed an empty string.
            return [json.loads(line) for line in f if line.strip()]
    except FileNotFoundError:
        print(f"Warning: {path} not found")
        return []


def _write_jsonl(path, records):
    """Write *records* to *path*, one JSON document per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record) + '\n' for record in records)


def merge_jsonl_files(train_datasets, output_dir='./data/math', input_dir='./data'):
    """Merge per-dataset JSONL train/test files into two combined files.

    For every name in *train_datasets* the files
    ``{input_dir}/math_{name}/math_{name}_train.jsonl`` and
    ``{input_dir}/math_{name}/math_{name}_test.jsonl`` are read. Records are
    concatenated in the order the datasets are given and written to
    ``math_train.jsonl`` and ``math_test.jsonl`` inside *output_dir*.

    Args:
        train_datasets: Iterable of dataset names to merge.
        output_dir: Directory for the merged files; created if missing.
        input_dir: Root directory containing the ``math_{name}`` folders.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.

    A missing input file only produces a printed warning. Blank lines in the
    input files are ignored. A summary with the record counts is printed at
    the end.
    """
    os.makedirs(output_dir, exist_ok=True)

    train_data = []
    test_data = []

    for dataset in train_datasets:
        # Plain '/' joins keep the paths (and the warning text) identical to
        # the historical './data/math_<name>/...' form for the default root.
        prefix = f'{input_dir}/math_{dataset}/math_{dataset}'
        train_data.extend(_read_jsonl(f'{prefix}_train.jsonl'))
        test_data.extend(_read_jsonl(f'{prefix}_test.jsonl'))

    # Everything is read before anything is written, so a malformed input
    # line aborts the merge without touching existing output files.
    _write_jsonl(os.path.join(output_dir, 'math_train.jsonl'), train_data)
    _write_jsonl(os.path.join(output_dir, 'math_test.jsonl'), test_data)

    print(f"Merged {len(train_data)} training examples and {len(test_data)} test examples")

# Dataset names handed to merge_jsonl_files; merged in the order listed.
train_datasets = [
    "algebra",
    "prealgebra",
    "counting_and_probability",
    "geometry",
    "precalculus",
    "number_theory",
    "intermediate_algebra",
]

# Runs the merge with the default output directory.
merge_jsonl_files(train_datasets=train_datasets)